{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7a48fe19-1b23-41b8-928f-0eb9636c74e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ebc19799-df61-4bde-827b-34e48050d6ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.autograd import Variable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "abbf0caa-e0fa-4491-b8dd-598ff1047432",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x749bd40c6cf0>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.manual_seed(42)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc48f883-570a-4c14-828f-a0852eb21ed1",
   "metadata": {},
   "source": [
    "- 反向传播，链式法则，内部非叶子节点（non-leaf node，哪怕 requires_grad 为 true，且其存在 grad_fn）也是会算梯度的，只是用完就置空了，\n",
    "    - 因此如果相查看内部非叶子节点的 grad，需要 retain_graph 保留在计算图中;\n",
    "- 深度神经网络中的中间层 layer 的参数（weights & bias）它们是内部节点呢，还是叶子节点呢？\n",
    "    - 是叶子节点；\n",
    "- 不要轻易地关闭 warnings，有助于排查/定位问题；\n",
    "    - warnings 不会导致程序 dump，但不推荐，因为有可能导致程序的运行不符合预期；\n",
    "    - 对于自己写的代码，出于健壮性或者可快速定位问题的考虑，也可以尝试多写 warnings"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61f5483d-e147-4c95-9be3-f3d24cbee19a",
   "metadata": {},
   "source": [
    "## multi head (output/branch) architecture"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d476bffa-a357-4764-8d28-b3ae9d9cd081",
   "metadata": {},
   "source": [
    "- https://www.bilibili.com/video/BV1o24y1b7tk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "61e6d2b3-e605-472c-ae9a-ae59919d6637",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<img src=\"../imgs/multi_loss.PNG\" width=\"100\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Image(url='../imgs/multi_loss.PNG', width=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6060b30f-45c3-4d19-bddd-39206bedbcbf",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[5], line 12\u001b[0m\n\u001b[1;32m      9\u001b[0m d\u001b[38;5;241m.\u001b[39mbackward()\n\u001b[1;32m     11\u001b[0m \u001b[38;5;66;03m# RuntimeError: Trying to backward through the graph a second time\u001b[39;00m\n\u001b[0;32m---> 12\u001b[0m \u001b[43me\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/torch/_tensor.py:522\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m    512\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m    513\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m    514\u001b[0m         Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m    515\u001b[0m         (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    520\u001b[0m         inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[1;32m    521\u001b[0m     )\n\u001b[0;32m--> 522\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    523\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[1;32m    524\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/lib/python3.10/site-packages/torch/autograd/__init__.py:266\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m    261\u001b[0m     retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m    263\u001b[0m \u001b[38;5;66;03m# The reason we repeat the same comment below is that\u001b[39;00m\n\u001b[1;32m    264\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m    265\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 266\u001b[0m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m  \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m    267\u001b[0m \u001b[43m    \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    268\u001b[0m \u001b[43m    \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    269\u001b[0m \u001b[43m    \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    270\u001b[0m \u001b[43m    \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    271\u001b[0m \u001b[43m    \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    272\u001b[0m \u001b[43m    \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m    273\u001b[0m \u001b[43m    \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m    274\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward."
     ]
    }
   ],
   "source": [
    "a = Variable(torch.rand(1, 4), requires_grad=True)\n",
    "b = a**2\n",
    "c = b*2\n",
    "\n",
    "d = c.mean()\n",
    "e = c.sum()\n",
    "\n",
    "\n",
    "d.backward()\n",
    "\n",
    "# RuntimeError: Trying to backward through the graph a second time\n",
    "e.backward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e6734e1b-6f8a-4b70-8448-6b016ad46599",
   "metadata": {},
   "outputs": [],
   "source": [
    "a = Variable(torch.rand(1, 4), requires_grad=True)\n",
    "b = a**2\n",
    "c = b*2\n",
    "\n",
    "d = c.mean()\n",
    "e = c.sum()\n",
    "\n",
    "\n",
    "d.backward(retain_graph=True)\n",
    "\n",
    "e.backward()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0cbf9a6-1096-427a-842c-8f1810897aba",
   "metadata": {},
   "source": [
    "$$\n",
    "\\begin{split}\n",
    "&b_i=a_i^2\\\\\n",
    "&c_i=2b_i=2a_i^2\\\\\n",
    "&d=\\frac{\\sum_ic_i}4=\\frac{\\sum_i 2a_i^2}4\\\\\n",
    "&e=\\sum_i c_i=\\sum_i 2a_i^2\n",
    "\\end{split}\n",
    "$$\n",
    "\n",
    "$$\n",
    "\\begin{split}\n",
    "&\\frac{\\partial d}{\\partial a_i}=a_i\\\\\n",
    "&\\frac{\\partial e}{\\partial a_i}=4a_i\n",
    "\\end{split}\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "98b58218-6feb-4996-a48c-8828367ede78",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.3904, 0.6009, 0.2566, 0.7936]], requires_grad=True)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "560bf7b0-d643-4f22-8404-b6a44c293627",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1.9522, 3.0045, 1.2829, 3.9682]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7fda9aa3-e12a-436f-8459-5b3e32cf9d78",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1.9522, 3.0045, 1.2829, 3.9682]], grad_fn=<MulBackward0>)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "5*a"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "624f5d5d-100b-4c17-ac75-dfc4c791bc04",
   "metadata": {},
   "source": [
    "- suppose you first back-propagate loss1, then loss2 (you can also do the reverse)\n",
    "\n",
    "```\n",
    "l1.backward(retain_graph=True)\n",
    "l2.backward() # now the graph is freed, and next process of batch gradient descent is ready\n",
    "\n",
    "optimizer.step() # update the network parameters\n",
    "\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f61fa935-28dd-4253-a2a8-b2a93c45167a",
   "metadata": {},
   "source": [
    "## non-leaf node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "c809ffe4-a450-4931-a79b-cfa028befc10",
   "metadata": {},
   "outputs": [],
   "source": [
    "a = Variable(torch.rand(1, 4), requires_grad=True)\n",
    "b = a**2\n",
    "c = b*2\n",
    "\n",
    "d = c.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "de3cc5ba-6b79-4a5e-a929-3f446e4bddcb",
   "metadata": {},
   "outputs": [],
   "source": [
    "d.backward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "ac024ac4-4396-4814-a86c-1d1d2708e200",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_10679/3238518479.py:1: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at aten/src/ATen/core/TensorBody.h:489.)\n",
      "  b.grad\n"
     ]
    }
   ],
   "source": [
    "b.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "a8031a07-97f4-40c3-9237-12adb097b2c7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.8851, 0.0177, 0.8735, 0.3523]], grad_fn=<PowBackward0>)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "10fbdf8c-7700-4d11-be9a-63db21f97431",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b.is_leaf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "df991db1-c5d5-43d6-96a8-5fd224823088",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.5000, 0.5000, 0.5000, 0.5000]])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = Variable(torch.rand(1, 4), requires_grad=True)\n",
    "b = a**2\n",
    "b.retain_grad()\n",
    "c = b*2\n",
    "\n",
    "d = c.mean()\n",
    "d.backward()\n",
    "b.grad"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df77be8b-d141-4a03-955f-3b133a22d10f",
   "metadata": {},
   "source": [
    "$$\n",
    "\\begin{split}\n",
    "&d = \\frac{\\sum_i c_i}{4}=\\frac{\\sum_i 2b_i}{4}=\\frac{\\sum_i b_i}2\\\\\n",
    "&\\frac{\\partial d}{\\partial b_i}=\\frac12\n",
    "\\end{split}\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "398633c8-d513-42c7-97ea-65c7800f8910",
   "metadata": {},
   "source": [
    "### nn 中间层的weights 其实也是 leaf node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "5cbc8172-fea4-4b29-83b5-e7ef35289fcf",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(MLP, self).__init__()\n",
    "        self.fc1 = nn.Linear(28*28, 512)\n",
    "        self.fc2 = nn.Linear(512, 512)\n",
    "        self.fc3 = nn.Linear(512, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = nn.Flatten(x)\n",
    "        x = self.fc3(nn.ReLU(self.fc2(nn.ReLU(self.fc1(x)))))\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "a5d6c274-5da5-43ea-9239-2e6e53aefbe1",
   "metadata": {},
   "outputs": [],
   "source": [
    "mlp = MLP()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "83745b46-d770-45d0-917d-f191ffd85fdc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mlp.fc1.weight.is_leaf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "2713148d-3255-427a-936c-d92a9a060550",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mlp.fc2.weight.is_leaf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "bd731863-b8fe-4f96-946b-c980ae206fde",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mlp.fc3.weight.is_leaf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1a1a1db-b7d1-468b-a4f1-69617f675272",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
